adds count num sequences and tokens metric #346
Conversation
LGTM, thanks!
I added questions, typo fixes, and artifact notes inline.
None of them change the logic.
import torch
import numpy as np

from fuse.eval.metrics.metrics_common import MetricPerBatchDefault


class MetricCountSeqAndTokens(MetricPerBatchDefault):
General question:
I'm not sure counting the sequences and tokens should be defined as a metric. I don't have another suggestion, it just sounds weird :)
What do you think of that?
It uses the metric mechanism, and it's ok to me that it just counts some stats.
) -> None:
    """
    :param encoder_input: key to the encoder_input
    :param ignore_index: token_id to ignore (not to count), typically pad token id
I think it should be able to support a list of token ids to ignore, unless you want to force the user to ignore only the PAD one.
I went with just one to be more efficient - typically we would like to just skip the padding.
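The trade-off being discussed can be sketched like this (hypothetical helper names, not the actual implementation): a single ignore_index needs only one elementwise comparison, while a list of ids would go through `torch.isin` and do more work per element:

```python
import torch

def count_tokens_single(ids: torch.Tensor, ignore_index: int) -> int:
    # One elementwise comparison: cheap, and enough to skip the pad token.
    return int((ids != ignore_index).sum())

def count_tokens_multi(ids: torch.Tensor, ignore_ids: list) -> int:
    # Generalization to several ignored ids; heavier per element.
    return int((~torch.isin(ids, torch.tensor(ignore_ids))).sum())
```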
    :param kwargs: additional super class arguments
    """
    super().__init__(
        seq_num="seq_num",  # collect log_probs - output of _count_seq_and_tokens_update
obsolete comments in this line and the following one
👍
def _count_seq_and_tokens_update(
    batch_dict: dict,
    encoder_input_key: str,
encoder_input_key: Union[str, None]
or
encoder_input_key: Optional[str]
It's a must. Why optional?
def _count_seq_and_tokens_compute(
    self,
    seq_num: List[np.ndarray],
seq_num will be a list of numpy arrays such that each entry represents a batch? If so, how often is the metric being calculated? Each epoch?
I forgot to ask these earlier :)
Each sub-epoch, and each entry is a batch.
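So the compute step presumably just sums the per-batch entries collected over a sub-epoch. A minimal sketch, assuming each list entry is a scalar numpy array holding one batch's count (hypothetical function name):

```python
import numpy as np
from typing import Dict, List

def count_seq_and_tokens_compute(seq_num: List[np.ndarray],
                                 token_num: List[np.ndarray]) -> Dict[str, int]:
    # Each entry was produced by the update step for one batch;
    # summing over the list yields the sub-epoch totals.
    return {"seq_num": int(np.sum(seq_num)), "token_num": int(np.sum(token_num))}
```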
    self,
    seq_num: List[np.ndarray],
    token_num: List[np.ndarray],
) -> float:
returns a dict
👍
    batch_dict: dict,
    encoder_input_key: str,
    ignore_index: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
-> dict[str, Tensor]
👍
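Taking the review suggestions together, the update step might end up looking roughly like this. A sketch only, not the merged code, assuming the encoder input is a token-id tensor of shape (batch, seq_len):

```python
from typing import Dict, Optional
import torch

def count_seq_and_tokens_update(batch_dict: dict,
                                encoder_input_key: str,
                                ignore_index: Optional[int] = None) -> Dict[str, torch.Tensor]:
    # Read the encoder input from the batch, then count the sequences and the
    # non-ignored tokens; a dict is returned per the review suggestion above.
    encoder_input = batch_dict[encoder_input_key]
    seq_num = torch.tensor(encoder_input.shape[0])
    if ignore_index is None:
        token_num = torch.tensor(encoder_input.numel())
    else:
        token_num = (encoder_input != ignore_index).sum()
    return {"seq_num": seq_num, "token_num": token_num}
```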
One last comment: did you try to write a test for it, so we'll have it covered? If time doesn't permit, maybe open a card on Monday and we'll get to it later.
Thanks for the review and useful comments @SagiPolaczek
LGTM
Merging it to match inner-source code.
No description provided.